"""Reader for training log.

See lib/Analysis/TrainingLogger.cpp for a description of the format.
"""
import ctypes
import dataclasses
import io
import json
import math
import sys
from typing import List, Optional

# Maps the element-type names that appear in a tensor spec's 'type' field to
# the ctypes scalar type used to decode each element of that tensor.
_element_types = {
    'float': ctypes.c_float,
    'double': ctypes.c_double,
    'int8_t': ctypes.c_int8,
    'uint8_t': ctypes.c_uint8,
    'int16_t': ctypes.c_int16,
    'uint16_t': ctypes.c_uint16,
    'int32_t': ctypes.c_int32,
    'uint32_t': ctypes.c_uint32,
    'int64_t': ctypes.c_int64,
    'uint64_t': ctypes.c_uint64,
}


@dataclasses.dataclass(frozen=True)
class TensorSpec:
  """Name, port, shape and element type of one tensor in the training log.

  NOTE(review): `shape` is a list, so despite frozen=True instances are not
  hashable (the dataclass-generated __hash__ hashes every field).
  """
  # Tensor name, as written in the log header.
  name: str
  # Port value from the header, stored as given (not converted).
  port: int
  # Dimension sizes; the element count is their product.
  shape: List[int]
  # ctypes scalar type of each element (one of the `_element_types` values).
  element_type: type

  @classmethod
  def from_dict(cls, d: dict):
    """Build a spec from one tensor entry of the log's JSON header.

    Args:
      d: dict with 'name', 'port', 'shape' (iterable of int-convertible
        values) and 'type' (a key of `_element_types`).

    Returns:
      A new spec instance (of `cls`).

    Raises:
      KeyError: if a required key is missing.
      ValueError: if 'type' is not a supported element type name.
    """
    name = d['name']
    port = d['port']
    shape = [int(e) for e in d['shape']]
    element_type_str = d['type']
    element_type = _element_types.get(element_type_str)
    if element_type is None:
      raise ValueError(f'unknown type: {element_type_str}')
    return cls(name=name, port=port, shape=shape, element_type=element_type)


class TensorValue:
  """Read-only, indexable view over the elements of one logged tensor.

  The raw bytes are reinterpreted in place (via a ctypes pointer) as values of
  `spec.element_type`; indexing returns the Python scalar ctypes produces.
  """

  def __init__(self, spec: TensorSpec, buffer: bytes):
    """Wrap `buffer` as the tensor described by `spec`.

    Args:
      spec: the tensor's spec; its shape gives the element count and its
        element_type gives how each element is decoded.
      buffer: raw element bytes, read directly by ctypes in native byte order.
        NOTE(review): presumably the logger writes native layout; confirm
        against lib/Analysis/TrainingLogger.cpp.

    Raises:
      ValueError: if `buffer` is too short to hold every element. Without
        this check, indexing would read past the end of the buffer.
    """
    self._spec = spec
    # Keep a reference to the bytes: the ctypes pointer below aliases them.
    self._buffer = buffer
    self._len = math.prod(self._spec.shape)
    needed = self._len * ctypes.sizeof(self._spec.element_type)
    if len(self._buffer) < needed:
      raise ValueError(f'tensor {self._spec.name}: buffer holds '
                       f'{len(self._buffer)} bytes, need {needed}')
    self._view = ctypes.cast(self._buffer,
                             ctypes.POINTER(self._spec.element_type))

  def spec(self) -> TensorSpec:
    """Return the spec this value was created with."""
    return self._spec

  def __len__(self) -> int:
    """Return the number of elements (product of the spec's shape)."""
    return self._len

  def __getitem__(self, index):
    """Return element `index`; negative indices are rejected."""
    if index < 0 or index >= self._len:
      raise IndexError(f'Index {index} out of range [0..{self._len})')
    return self._view[index]

  def __iter__(self):
    """Iterate over the elements in order."""
    view = self._view
    return (view[i] for i in range(self._len))


def read_tensor(fs: io.BufferedReader, ts: TensorSpec) -> TensorValue:
  """Read one tensor's raw bytes from `fs` and wrap them as a TensorValue.

  Args:
    fs: the log stream, positioned at the start of the tensor's raw bytes.
    ts: spec giving the tensor's shape and element type.

  Returns:
    A TensorValue over exactly prod(ts.shape) * sizeof(ts.element_type) bytes.

  Raises:
    EOFError: if the stream ends before the whole tensor could be read.
  """
  size = math.prod(ts.shape) * ctypes.sizeof(ts.element_type)
  data = fs.read(size)
  # A short read means the log is truncated. Fail loudly here instead of
  # handing an undersized buffer to TensorValue.
  if len(data) != size:
    raise EOFError(
        f'truncated log: tensor {ts.name} needs {size} bytes, got {len(data)}')
  return TensorValue(ts, data)


def pretty_print_tensor_value(tv: TensorValue):
  """Print `tv` on one line as '<name>: v0,v1,...'."""
  name = tv.spec().name
  rendered = ','.join(map(str, tv))
  print(f'{name}: {rendered}')


def read_header(f: io.BufferedReader):
  """Parse the JSON header line at the current position of `f`.

  Returns:
    (feature_specs, score_spec, advice_spec): the list of feature TensorSpecs
    from the header's 'features' entry, plus the score and advice specs, each
    None when the header has no such key.
  """
  header = json.loads(f.readline())

  def optional_spec(key):
    # Optional header entries map to None when absent.
    return TensorSpec.from_dict(header[key]) if key in header else None

  feature_specs = [TensorSpec.from_dict(entry) for entry in header['features']]
  return feature_specs, optional_spec('score'), optional_spec('advice')


def read_one_observation(context: Optional[str], event_str: str,
                         f: io.BufferedReader, tensor_specs: List[TensorSpec],
                         score_spec: Optional[TensorSpec]):
  """Read one observation, and its score if the log has one, from `f`.

  Args:
    context: the context in effect before this event. It is returned
      unchanged unless `event_str` is a context line.
    event_str: a JSON event line already read from `f` (str or bytes;
      json.loads accepts both).
    f: the log stream, positioned just after `event_str`.
    tensor_specs: specs of the feature tensors, in the order they are stored.
    score_spec: spec of the score tensor, or None if the log has no score.

  Returns:
    (context, observation_id, features, score). `features` is a list of
    tensors in `tensor_specs` order; `score` is None when `score_spec` is None.

  Raises:
    ValueError: if the score's 'outcome' id does not match the observation id.
  """
  event = json.loads(event_str)
  if 'context' in event:
    # A context line is followed by the observation's own event line.
    context = event['context']
    event = json.loads(f.readline())
  observation_id = int(event['observation'])
  features = [read_tensor(f, ts) for ts in tensor_specs]
  # Skip the rest of the line after the raw feature bytes (presumably just
  # the newline terminator).
  f.readline()
  score = None
  if score_spec is not None:
    score_header = json.loads(f.readline())
    outcome_id = int(score_header['outcome'])
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if outcome_id != observation_id:
      raise ValueError(f'score outcome {outcome_id} does not match '
                       f'observation {observation_id}')
    score = read_tensor(f, score_spec)
    f.readline()
  return context, observation_id, features, score


def read_stream(fname: str):
  """Generate the observations stored in the training log at `fname`.

  Yields:
    (context, observation_id, features, score) tuples, one per observation.
    The context is carried forward from the most recent context line.
  """
  with io.BufferedReader(io.FileIO(fname, 'rb')) as stream:
    feature_specs, score_spec, _ = read_header(stream)
    current_context = None
    while event_line := stream.readline():
      (current_context, observation_id, features,
       score) = read_one_observation(current_context, event_line, stream,
                                     feature_specs, score_spec)
      yield current_context, observation_id, features, score


def main(args):
  """Pretty-print every observation of the training log named by args[1].

  A 'context: ...' line is printed whenever the context changes. Each
  observation then prints its id, its feature tensors and its score, if any.

  Args:
    args: command-line arguments as in sys.argv; args[1] is the log path.
  """
  if len(args) < 2:
    raise SystemExit(
        f'usage: {args[0] if args else "<script>"} <training log file>')
  last_context = None
  for ctx, obs_id, features, score in read_stream(args[1]):
    if last_context != ctx:
      print(f'context: {ctx}')
      last_context = ctx
    print(f'observation: {obs_id}')
    for fv in features:
      pretty_print_tensor_value(fv)
    # Compare against None. A TensorValue defines __len__, so one with zero
    # elements is falsy, but it should still be printed.
    if score is not None:
      pretty_print_tensor_value(score)


# Script entry point: pass the full argv so main() can read the log path
# from argv[1].
if __name__ == '__main__':
  main(sys.argv)
